#ifndef NCCL_DEVICE_SYMMETRIC_H_
#define NCCL_DEVICE_SYMMETRIC_H_

#include "nccl.h"
#include "nccl_common.h"
#include "bitops.h"

// Compile-time limits shared by the symmetric-kernel declarations in this header.

// Length of the per-block arrays in ncclSymDevBase and the number of per-block
// LL buffers accounted for by ncclSymDevBase::size().
constexpr int ncclSymMaxBlocks = 64;
// Thread-count factor used when sizing LL slots (ncclSymLLMaxSlots) and LL
// epochs (ncclSymLLEpochSize).
constexpr int ncclSymMaxThreads = 512;
// Default `eltSize` of ncclSymLLMaxSlots() and the numerator factor of its slot count.
// NOTE(review): units are presumably bytes -- confirm against the LL kernels.
constexpr int ncclSymLLMaxEltSize = 64;

// Number of LL slots available for elements of size `eltSize`, computed as
// (ncclSymMaxThreads * ncclSymLLMaxEltSize) / eltSize with integer division.
// With the default `eltSize` (ncclSymLLMaxEltSize) the result is ncclSymMaxThreads.
// `eltSize` must be non-zero; it is used as a divisor.
constexpr __host__ __device__ int ncclSymLLMaxSlots(int eltSize = ncclSymLLMaxEltSize) {
  return (ncclSymLLMaxEltSize * ncclSymMaxThreads) / eltSize;
}

// Size of one LL epoch region for a communicator of `nRanks` ranks.
// The result is the larger of
//   - ncclSymMaxThreads * nRanks * 8, and
//   - (slot count at the maximum element size) * ncclSymLLMaxEltSize,
// multiplied by 2 for the LL overhead.
// Kept as a single return statement so the function stays a valid C++11 constexpr.
constexpr __host__ __device__ int ncclSymLLEpochSize(int nRanks) {
  return maxval(ncclSymMaxThreads * nRanks * 8,
                ncclSymLLMaxSlots() * ncclSymLLMaxEltSize) * /*LL Overhead*/2;
}

// Header of the symmetric device state. In memory the header is followed by the
// barInboxPerPeer[] flexible array (padded to 16 bytes) and then by the
// per-block LL buffers; size() returns the byte count of that whole layout and
// ncclSymDevBase_getLLBuf() locates an LL buffer within it.
struct alignas(16) ncclSymDevBase {
  // Fixed-length arrays with ncclSymMaxBlocks entries each.
  // NOTE(review): "Mc"/"Uc" presumably stand for multicast/unicast -- confirm.
  uint32_t llEpoch[ncclSymMaxBlocks];
  uint32_t barEpochMc[ncclSymMaxBlocks], barEpochUc[ncclSymMaxBlocks];
  uint32_t barInboxMc[ncclSymMaxBlocks];
  // Flexible array member; size() reserves ncclSymMaxBlocks*nRanks entries for it.
  uint32_t barInboxPerPeer[];

  // Total bytes needed for the header, the barInboxPerPeer[] storage rounded up
  // to a multiple of 16, and two LL epoch regions for each of ncclSymMaxBlocks blocks.
  static constexpr size_t size(int nRanks) {
    return sizeof(ncclSymDevBase) +
           alignUp(ncclSymMaxBlocks*nRanks*sizeof(uint32_t), 16) +
           // Widen the first factor so the whole product is evaluated in size_t;
           // as an int product it overflows once nRanks reaches 2048.
           size_t(ncclSymMaxBlocks) * /*epochs=*/2 * ncclSymLLEpochSize(nRanks);
  }
};

// Returns the LL buffer of `block` for the given `epoch`.
//
// The buffers live after the ncclSymDevBase header and its trailing
// barInboxPerPeer[] storage. Each block owns two consecutive epoch regions of
// ncclSymLLEpochSize(nRanks) bytes; the low bit of `epoch` selects which one.
//
//   base   - header that the trailing storage is laid out after
//   nRanks - rank count; determines the barInboxPerPeer[] and epoch sizes
//   block  - index of the block whose buffer is requested (expected >= 0)
//   epoch  - epoch counter; only its parity is used
static __device__ uint4* ncclSymDevBase_getLLBuf(struct ncclSymDevBase* base, int nRanks, int block, uint32_t epoch) {
  // All offsets are accumulated in size_t so the block/epoch products cannot
  // overflow 32-bit int arithmetic for large rank counts.
  size_t epochSize = ncclSymLLEpochSize(nRanks);
  // Skip over barInboxPerPeer[], padded to a 16-byte multiple.
  size_t offset = alignUp(ncclSymMaxBlocks*nRanks*sizeof(uint32_t), 16);
  // Skip to our block: every block has two epoch regions.
  offset += size_t(block) * /*epochs=*/2 * epochSize;
  // Select the even or odd epoch region.
  offset += (epoch & 1) * epochSize;
  // The storage starts right after the header struct.
  return (uint4*)((char*)(base + 1) + offset);
}

// Communicator description for the symmetric kernels; embedded by value in
// ncclSymDevArgs (field `comm`).
struct ncclSymDevComm {
  // Pointer to the ncclSymDevBase header.
  ncclSymDevBase* base;
  // Second pointer to an ncclSymDevBase.
  // NOTE(review): "Mc" presumably denotes a multicast mapping of the same state -- confirm.
  ncclSymDevBase* baseMc;
  // NOTE(review): presumably the per-rank address stride in 4GiB units -- confirm against users.
  uint32_t stride4G;
  // Rank count and a rank index (presumably this rank's -- confirm against users).
  int nRanks, rank;
  uint32_t nRanks_rcp32; // idivRcp32(nRanks)
};

// Argument bundle for a symmetric kernel: the communicator description plus the
// parameters of one operation. The struct is 16-byte aligned.
struct alignas(16) ncclSymDevArgs {
  // Communicator description, stored by value.
  struct ncclSymDevComm comm;
  // NOTE(review): presumably only meaningful for rooted operations -- confirm.
  int rootRank;
  uint64_t redOpArg; // must be collectively uniform
  // Element count of the operation.
  size_t nElts;
  // Input and output buffers, typed as raw byte pointers.
  char* input;
  char* output;
};

// Identifiers of the symmetric kernel variants, grouped by collective.
// No enumerator has an explicit value, so they number consecutively from 0 and
// ncclSymKernelId_Count equals the number of kernel ids listed before it.
// Reordering or inserting entries changes these values.
enum ncclSymKernelId {
  // AllReduce variants
  ncclSymKernelId_AllReduce_AGxLL_R,
  ncclSymKernelId_AllReduce_AGxLLMC_R,
  ncclSymKernelId_AllReduce_RSxLD_AGxST,
  ncclSymKernelId_AllReduce_RSxLDMC_AGxSTMC,

  // AllGather variants
  ncclSymKernelId_AllGather_LL,
  ncclSymKernelId_AllGather_LLMC,
  ncclSymKernelId_AllGather_ST,
  ncclSymKernelId_AllGather_STMC,

  // ReduceScatter variants
  ncclSymKernelId_ReduceScatter_LL,
  ncclSymKernelId_ReduceScatter_LD,
  ncclSymKernelId_ReduceScatter_LDMC,

  // Sentinel: number of kernel ids above; keep last.
  ncclSymKernelId_Count
};

// Declared here, defined elsewhere. `red` is passed as an int carrying a
// ncclDevRedOp_t value (see the inline parameter comment).
// NOTE(review): presumably reports whether a symmetric kernel exists for the
// given function / reduction op / datatype combination -- confirm at the definition.
bool ncclSymImplemented(ncclFunc_t fn, int/*ncclDevRedOp_t*/ red, ncclDataType_t ty);

// Declared here, defined elsewhere. Takes the communicator, the function,
// reduction op, datatype and element count; `estTimeUs`, `kernelId`, `nBlocks`
// and `nWarps` are pointer parameters.
// NOTE(review): the pointer parameters are presumably outputs describing the
// chosen kernel and its launch shape -- confirm at the definition.
ncclResult_t ncclSymPickKernel(struct ncclComm* comm, ncclFunc_t fn, int/*ncclDevRedOp_t*/ red, ncclDataType_t ty, size_t nElts, float* estTimeUs, ncclSymKernelId* kernelId, int* nBlocks, int* nWarps);

// Generated by src/device/symmetric/generate.py
// The count and the list are only declared here (extern); their definitions are
// not in this header.
// NOTE(review): ncclSymKernelCount is presumably the length of ncclSymKernelList -- confirm.
extern int const ncclSymKernelCount;
extern void* const ncclSymKernelList[];
// Looks up a kernel by kernel id, reduction op (int carrying a ncclDevRedOp_t)
// and datatype, returning an untyped pointer.
// NOTE(review): behaviour for unsupported combinations is not visible here -- confirm.
void* ncclSymGetKernelPtr(ncclSymKernelId kernelId, int/*ncclDevRedOp_t*/ red, ncclDataType_t ty);
// Returns a C string for a kernel id; note the parameter is a plain int rather
// than ncclSymKernelId.
const char* ncclSymKernelIdToString(int kernelId);

#endif
